

# ################################################
# Just used in reparameterized MLE + unified approach
# ################################################

# Fitted outcome mean under the reparameterized outcome model:
#
#   E[Y | A, C, M] = f(M, A, C)
#                    - sum_i px_i * sum_m f(m, A_i, C_i) * p(M = m | A = 0, C_i)
#                    + w0 + wa * A
#
# Args:
#   dat:    data frame with columns `A` and `M` plus every column named in
#           names(beta_m)[-1] and in the names of the f-coefficients.
#           Column 1 of `dat` is paired with the first entry of beta_m
#           (presumably an intercept column -- confirm against the caller).
#   beta_y: named numeric vector of length >= 2. The last two entries are
#           w0 and wa; all preceding entries are the coefficients of f and
#           are matched to columns of `dat` by name.
#   beta_m: named coefficient vector of the logistic mediator model.
#   px:     weights multiplied elementwise into the per-row centering terms
#           before they are summed (presumably 1/n on the rows that define
#           the empirical distribution -- confirm against the caller).
#
# Returns:
#   Column matrix with one fitted value per row of `dat`.
#
# NOTE(review): process_data() is defined elsewhere; this code relies on it
# returning a numeric matrix whose columns line up with colnames(dat).
estimate_Y <- function(dat, beta_y, beta_m, px) {

  p <- length(beta_y)
  stopifnot(p >= 2)

  # seq_len() keeps beta_f empty when beta_y holds only (w0, wa);
  # 1:(p - 2) would yield c(1, 0) and wrongly pick up beta_y[1].
  beta_f <- beta_y[seq_len(p - 2)]
  w0 <- beta_y[p - 1]
  wa <- beta_y[p]

  # Counterfactual copies of the data: observed A with M forced to 0 / 1,
  # and A forced to 0 with the observed M.
  dat_am0 <- process_data(dat, a = dat$A, m = 0)
  dat_am1 <- process_data(dat, a = dat$A, m = 1)
  dat_a0m <- process_data(dat, a = 0, m = dat$M)

  # p(M = m | A = 0, C)
  idx_m <- c(1, match(names(beta_m)[-1], colnames(dat)))
  p_m1a0 <- stats::plogis(dat_a0m[, idx_m, drop = FALSE] %*% beta_m)
  p_m0a0 <- 1 - p_m1a0

  # Centering constant: sum_i px_i * sum_m f(m, A_i, C_i) * p(M = m | A = 0, C_i)
  idx_f <- match(names(beta_f), colnames(dat))
  f_m0ac <- dat_am0[, idx_f, drop = FALSE] %*% beta_f
  f_m1ac <- dat_am1[, idx_f, drop = FALSE] %*% beta_f
  sum_train <- sum(px * (f_m1ac * p_m1a0 + f_m0ac * p_m0a0))

  # f(M, A, C) at the observed values
  f_mac <- as.matrix(dat[, idx_f, drop = FALSE]) %*% beta_f

  # E[Y | A, C, M] = f - sum_train + w0 + wa * A
  f_mac - sum_train + w0 + wa * dat$A
}


# ################################################
# Prediction in general
# ################################################

# Mean squared prediction error on the rows `idx_test` of `dat`.
#
# The prediction of Y depends on opt$estimator:
#   "G-formula"    : E*[Y | A, C]  (M summed out), with or without the
#                    reparameterized outcome model (opt$reparam).
#   "IPW" / "AIPW" : E*[Y | C]     (A and M summed out), opt$reparam == FALSE.
#   "Mixed"        : E*[Y | M, C]  (A summed out),       opt$reparam == FALSE.
#
# Args:
#   dat:      data frame with columns `A`, `M` and every column named in the
#             coefficient vectors. Column 1 is paired with the first entry
#             of each coefficient vector (presumably an intercept -- confirm).
#   idx_test: row indices of `dat` on which predictions are evaluated.
#   Y_test:   observed outcomes for those rows, in the same order.
#   beta:     list with named coefficient vectors `beta_y` (outcome),
#             `beta_m` (logistic mediator model) and `beta_a` (logistic
#             exposure model).
#   px:       weights forwarded to estimate_Y() (reparameterized case only).
#   opt:      list with `reparam` (logical) and `estimator` (string).
#
# Returns:
#   Scalar mean((Y_test - Y_hat)^2).
#
# NOTE(review): process_data() is defined elsewhere; this code relies on it
# returning a numeric matrix whose columns line up with colnames(dat).
compute_mse <- function(dat, idx_test, Y_test, beta, px, opt) {

  reparam <- opt$reparam
  estimator <- opt$estimator
  plain_model <- reparam == FALSE

  beta_y <- beta$beta_y
  beta_m <- beta$beta_m
  beta_a <- beta$beta_a

  # Column positions in `dat` for a named coefficient vector; the first
  # coefficient is always paired with column 1.
  cols_for <- function(coefs) {
    c(1, match(names(coefs)[-1], colnames(dat)))
  }

  # Test rows of a design as a matrix. drop = FALSE keeps the row/column
  # layout intact even when idx_test selects a single row.
  test_design <- function(x, cols) {
    as.matrix(x[idx_test, cols, drop = FALSE])
  }

  Y_hat <- NULL

  # +++++++++++++++++++++++++++++++++
  # G-formula: sum out {M}
  # +++++++++++++++++++++++++++++++++
  if (estimator == "G-formula") {

    dat_am0 <- process_data(dat, a = dat$A, m = 0)
    dat_am1 <- process_data(dat, a = dat$A, m = 1)

    # p*(M | A, C) at the observed A
    idx_m <- cols_for(beta_m)
    p_m1 <- stats::plogis(test_design(dat, idx_m) %*% beta_m)
    p_m0 <- 1 - p_m1

    if (plain_model) {
      # E*[Y | M = m, A, C]
      idx_y <- cols_for(beta_y)
      Yhat_m1 <- test_design(dat_am1, idx_y) %*% beta_y
      Yhat_m0 <- test_design(dat_am0, idx_y) %*% beta_y
    } else {
      # E*[Y | A, C, M = m] = f(m, .) - sum_i {sum_m f_i * p(M | a = 0, c_i)} + w0
      # estimate_Y() needs all rows (the centering term sums over them),
      # so subset to the test rows only afterwards.
      Yhat_m1 <- estimate_Y(as.data.frame(dat_am1), beta_y, beta_m, px)[idx_test]
      Yhat_m0 <- estimate_Y(as.data.frame(dat_am0), beta_y, beta_m, px)[idx_test]
    }

    # E*[Y | A, C] = sum_M E*[Y | A, C, M] p*(M | A, C)
    Y_hat <- Yhat_m1 * p_m1 + Yhat_m0 * p_m0
  }

  # +++++++++++++++++++++++++++++++++
  # IPW and AIPW: sum out {A, M}
  # +++++++++++++++++++++++++++++++++
  if (estimator %in% c("IPW", "AIPW") && plain_model) {

    dat_a0m0 <- process_data(dat, a = 0, m = 0)
    dat_a0m1 <- process_data(dat, a = 0, m = 1)
    dat_a1m0 <- process_data(dat, a = 1, m = 0)
    dat_a1m1 <- process_data(dat, a = 1, m = 1)

    # p*(A | C)
    idx_a <- cols_for(beta_a)
    p_a1 <- stats::plogis(test_design(dat, idx_a) %*% beta_a)
    p_a0 <- 1 - p_a1

    # p*(M = 1 | A = a, C), evaluated separately at a = 0 and a = 1.
    # Using the probability at the observed A for both terms would not
    # marginalise over A correctly.
    idx_m <- cols_for(beta_m)
    p_m1a0 <- stats::plogis(test_design(dat_a0m1, idx_m) %*% beta_m)
    p_m1a1 <- stats::plogis(test_design(dat_a1m1, idx_m) %*% beta_m)

    # E*[Y | M = m, A = a, C]
    idx_y <- cols_for(beta_y)
    y_a0m0 <- test_design(dat_a0m0, idx_y) %*% beta_y
    y_a0m1 <- test_design(dat_a0m1, idx_y) %*% beta_y
    y_a1m0 <- test_design(dat_a1m0, idx_y) %*% beta_y
    y_a1m1 <- test_design(dat_a1m1, idx_y) %*% beta_y

    # E*[Y | C] = sum_{A,M} E*[Y | A, C, M] p*(M | A, C) p*(A | C)
    Y_hat <- p_a0 * (y_a0m0 * (1 - p_m1a0) + y_a0m1 * p_m1a0) +
      p_a1 * (y_a1m0 * (1 - p_m1a1) + y_a1m1 * p_m1a1)
  }

  # +++++++++++++++++++++++++++++++++
  # Mixed: sum out {A}
  # +++++++++++++++++++++++++++++++++
  if (estimator == "Mixed" && plain_model) {

    dat_a0m1 <- process_data(dat, a = 0, m = 1)
    dat_a1m1 <- process_data(dat, a = 1, m = 1)
    dat_a0m <- process_data(dat, a = 0, m = dat$M)
    dat_a1m <- process_data(dat, a = 1, m = dat$M)

    # p*(A | C)
    idx_a <- cols_for(beta_a)
    p_a1 <- stats::plogis(test_design(dat, idx_a) %*% beta_a)
    p_a0 <- 1 - p_a1

    # p*(M = observed m | A = a, C): start from P(M = 1 | .) and flip to
    # P(M = 0 | .) on the rows where the observed mediator is 0.
    # The observed mediator is read from `dat` itself rather than from a
    # `dat_test` object living outside this function.
    idx_m <- cols_for(beta_m)
    p_m1a0 <- stats::plogis(test_design(dat_a0m1, idx_m) %*% beta_m)
    p_m1a1 <- stats::plogis(test_design(dat_a1m1, idx_m) %*% beta_m)
    is_m0 <- dat$M[idx_test] == 0
    p_ma0 <- p_m1a0
    p_ma0[is_m0] <- 1 - p_m1a0[is_m0]
    p_ma1 <- p_m1a1
    p_ma1[is_m0] <- 1 - p_m1a1[is_m0]

    # E*[Y | M = observed m, A = a, C]
    idx_y <- cols_for(beta_y)
    y_a0m <- test_design(dat_a0m, idx_y) %*% beta_y
    y_a1m <- test_design(dat_a1m, idx_y) %*% beta_y

    # E*[Y | M, C] = sum_A E*[Y | A, C, M] p(M | A, C) p(A | C)
    #                / sum_A p(M | A, C) p(A | C)
    Y_hat <- (y_a0m * p_ma0 * p_a0 + y_a1m * p_ma1 * p_a1) /
      (p_ma0 * p_a0 + p_ma1 * p_a1)
  }

  # Fail loudly instead of hitting an undefined (or global) Y_hat below.
  if (is.null(Y_hat)) {
    stop(
      "compute_mse: no prediction rule for estimator = '", estimator,
      "' with reparam = ", reparam,
      call. = FALSE
    )
  }

  mean((Y_test - Y_hat)^2)
}

